import torch

# Demonstrate how two successive unsqueeze(1) calls reshape a 2-D tensor.
# Each unsqueeze(1) inserts a new size-1 axis at position 1 and shifts the
# existing trailing axis right, so (4, 8) -> (4, 1, 8) -> (4, 1, 1, 8).
# NOTE(review): the (N, 1, 1, L) layout resembles a mask meant to broadcast
# against (N, heads, L, L) attention scores; confirm this is the intended use.

# Random integers in [0, 10) with shape (4, 8); randint's `high` is exclusive.
# Named `values` rather than `input` so the built-in input() is not shadowed.
values = torch.randint(0, 10, (4, 8))
print(values.shape)
print(values)

# Insert two singleton axes after dim 0: (4, 8) -> (4, 1, 1, 8).
values = values.unsqueeze(1).unsqueeze(1)
print(values.shape)
print(values)
